
Fix/Feat (trunc avg pool): Update truncation and average pool behaviour #1042

Open · wants to merge 3 commits into base: dev
Conversation

nickfraser (Collaborator) commented Oct 4, 2024

Proposal: Changes to TruncQuant and TruncAvgPool

Given that the behaviour of TruncQuant and TruncAvgPool has existed for many years, and some examples do rely on it, I think it's worth opening this up for discussion with the community.

Note: as far as I can tell, the only places TruncQuant is used are TruncAvgPool and TruncQuantAccumulator, but I think the issues can be understood through the lens of TruncAvgPool.

Motivation

I find the current implementation of TruncQuant and TruncAvgPool a little troubling, and incorrect outside of a specific use case. I'll explain further down in this document.

Current Implementation / Status

Before explaining my issues with the current implementation, let me first explain how TruncQuant and TruncAvgPool currently work.

Current TruncQuant Implementation

TruncQuant shifts the integer representation of the input left or right (i.e., multiplies or divides by a power of two) by the difference between the input and output bit widths. That new integer value is then reinterpreted using the same scale as the input.

Example:

import torch
from brevitas.core.function_wrapper import RoundSte
from brevitas.core.bit_width import BitWidthConst
from brevitas.core.quant.int import TruncIntQuant

x = torch.tensor([127]) # 0b01111111, the largest signed 8-bit value
tq = TruncIntQuant(
    bit_width_impl=BitWidthConst(4), # Truncate down to 4 bits
    float_to_int_impl=RoundSte(), # Round, with a straight-through estimator in the backward pass
)
y, scale, zp, bw = tq(x, scale=1, zero_point=0, input_bit_width=8)
print(f"({y}, {scale}, {zp}, {bw})")

This produces the following output:

(tensor([8.]), 1, 0, 4.0)

I.e., if two's complement is used, the binary representation of the underlying data has changed from 0b01111111 to 0b1000, but since the scale factor remains the same, the interpretation of these values changes from 127 -> 8. Effectively, the operation divides the value of the input by 2**(input_bit_width - output_bit_width), while also rounding the result. Also note that the most-significant bits of the input 8-bit type are always kept (possibly with rounding).
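This reinterpretation is equivalent to a rounded division; as a quick illustration in plain PyTorch (not Brevitas code):

import torch

# Shift right by (input_bit_width - output_bit_width) = 4 bits, with rounding.
print(torch.round(torch.tensor([127.]) / 2 ** (8 - 4))) # tensor([8.])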

Finally, there is no guard against overflow: if instead x = torch.tensor([255]) in the above example, the output is:

(tensor([16.]), 1, 0, 4.0)

which is an invalid IntQuantTensor, since the value 16 cannot be represented in 4 bits (if scale=1, zero_point=0).

Current TruncAvgPool Implementation

This needs to be split into two parts: the first is the behaviour of torch.nn.functional.avg_pool2d with an IntQuantTensor input; afterwards, the output of this functional call is modified and passed into a TruncQuant module.

IntQuantTensor and torch.nn.functional

When an IntQuantTensor is passed to torch.nn.functional.avg_pool2d, the scale and zero_point parameters are effectively ignored. The bit_width field is updated to the total number of bits required to represent the sum that underlies the average operation. The value field of the IntQuantTensor is simply passed through torch.nn.functional.avg_pool2d and placed in the result.
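Concretely, summing k*k b-bit values requires b + ceil(log2(k*k)) bits; a quick check of the numbers used below (plain Python, with k and the bit width taken from the following example):

import math

k, input_bit_width = 3, 8
# A 3x3 average sums 9 values, growing the accumulator by ceil(log2(9)) = 4 bits.
print(input_bit_width + math.ceil(math.log2(k * k))) # 12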

For example, the following code:

import torch

import torch.nn.functional as F
from brevitas.quant_tensor import IntQuantTensor
from brevitas.nn import QuantIdentity

torch.manual_seed(0)

k=3
x = torch.rand((1,1,k,k))
qi = QuantIdentity(return_quant_tensor=True, bit_width=8)

qi.train()
qi(x) # First forward pass initialises the scale
qx = qi(x) # Create valid `IntQuantTensor`

y = F.avg_pool2d(qx, kernel_size=k, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)

print(qx)
print(y)
print(y.value / y.scale) # Should be integers for a valid `IntQuantTensor`

produces the following output:

IntQuantTensor(value=tensor([[[[0.4972, 0.7704, 0.0910],
          [0.1331, 0.3082, 0.6373],
          [0.4902, 0.8894, 0.4552]]]], grad_fn=<MulBackward0>), scale=0.00700347451493144, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=True)
IntQuantTensor(value=tensor([[[[0.4747]]]], grad_fn=<AvgPool2DBackward0>), scale=0.00700347451493144, zero_point=0.0, bit_width=12.0, signed_t=True, training_t=True)
tensor([[[[67.7778]]]], grad_fn=<DivBackward0>)

Note that the value field of the output is no longer an integer multiple of the scale, meaning that this operation does not produce a valid IntQuantTensor.

Scaling and Passing to TruncQuant

Before the output of the torch.nn.functional.avg_pool2d call is passed to TruncQuant, the value field of the intermediate IntQuantTensor is multiplied by the value of the denominator in the average calculation - effectively turning the intermediate result from an AvgPool into a SumPool. This step corrects the mismatch between the value and scale of the intermediate output and converts it into a valid IntQuantTensor again.
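Reusing y and k from the avg_pool2d example above, this can be checked directly: multiplying the value field by the denominator k*k recovers the underlying integer sum.

# 67.7778 * 9 ≈ 610: an integer multiple of the scale again (up to float error).
print((y.value * k * k) / y.scale)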

This intermediate result is then passed to TruncQuant and the truncation / reinterpretation process described in the previous section is performed.

Overall, when the input and output bit widths of a TruncAvgPool operation are the same, something somewhat sane occurs, for example:

import torch

from brevitas.nn import TruncAvgPool2d, QuantIdentity

torch.manual_seed(0)

k=3
x = torch.rand((1,1,k,k))

qi = QuantIdentity(return_quant_tensor=True, bit_width=8)
qap = TruncAvgPool2d(kernel_size=(k,k), bit_width=8, float_to_int_impl_type='round', return_quant_tensor=True)

qi.train()
qi(x) # First forward pass initialises the scale
qi.eval()
qap.eval()
qx = qi(x) # Create valid `IntQuantTensor`

y = qap(qx)

print(qx)
print(y)
print(y.value / y.scale) # Should be integers for a valid `IntQuantTensor`

Which produces the following output:

IntQuantTensor(value=tensor([[[[0.4972, 0.7704, 0.0910],
          [0.1331, 0.3082, 0.6373],
          [0.4902, 0.8894, 0.4552]]]]), scale=0.00700347451493144, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False)
IntQuantTensor(value=tensor([[[[0.2661]]]]), scale=0.00700347451493144, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False)
tensor([[[[38.]]]])

Note that, since the same random seed is used as in the previous example, the output here, IntQuantTensor(value=tensor([[[[0.2661]]]]), scale=0.00700347451493144, zero_point=0.0, bit_width=8.0, ..., is directly comparable to the output of the previous example, IntQuantTensor(value=tensor([[[[0.4747]]]], ..., scale=0.00700347451493144, zero_point=0.0, bit_width=12.0, ....
Note that the bit_width has been reduced (12 to 8) as expected, while the scale factor remains the same in both. The dequantized value has been scaled by approximately k**2 / 16, where k**2 is the denominator of the average that occurred in the average pool, and 16 is 2**(input_bit_width - output_bit_width) as described in the TruncQuant section. An argument can be made that this output is a desirable one, even though the dequantized value differs significantly from the expected (unquantized) one.
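The output integer can be reproduced from the underlying sum (610, as recovered in the check after the avg_pool2d example): truncating from 12 to 8 bits divides by 2**(12 - 8) = 16, with rounding.

import torch

print(torch.round(torch.tensor(610.) / 2 ** (12 - 8))) # tensor(38.), matching the output above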

However, note that the input and output bit widths of the TruncAvgPool are identical there. If the bit width of TruncAvgPool is instead set to 12, the output becomes:

IntQuantTensor(value=tensor([[[[0.4972, 0.7704, 0.0910],
          [0.1331, 0.3082, 0.6373],
          [0.4902, 0.8894, 0.4552]]]]), scale=0.00700347451493144, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False)
IntQuantTensor(value=tensor([[[[4.2721]]]]), scale=0.00700347451493144, zero_point=0.0, bit_width=12.0, signed_t=True, training_t=False)
tensor([[[[610.]]]])

A side effect of the above functionality is that the average pool has been converted into a sum pool. Conversely, if the bit width of TruncAvgPool is set to 4, the output becomes:

IntQuantTensor(value=tensor([[[[0.4972, 0.7704, 0.0910],
          [0.1331, 0.3082, 0.6373],
          [0.4902, 0.8894, 0.4552]]]]), scale=0.00700347451493144, zero_point=0.0, bit_width=8.0, signed_t=True, training_t=False)
IntQuantTensor(value=tensor([[[[0.0140]]]]), scale=0.00700347451493144, zero_point=0.0, bit_width=4.0, signed_t=True, training_t=False)
tensor([[[[2.]]]])

Now the dequantized output is approximately k**2 / 256 times the expected (unquantized) value, since the TruncQuant module applies a much larger dividing factor.
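Again, this matches a direct computation on the underlying sum, now with a 2**(12 - 4) = 256 divisor:

import torch

print(torch.round(torch.tensor(610.) / 2 ** (12 - 4))) # tensor(2.), matching the output above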

Takeaways

When input_bitwidth==output_bitwidth in a TruncAvgPool layer, we get somewhat sane functionality, but outside of this I consider the behaviour at best "unintuitive" and at worst "buggy". Also, even the sane behaviour for input_bitwidth==output_bitwidth seems to break several rules implicit in the Brevitas codebase, specifically:

  • The dequantized output of a quantizer should be close to the input value
  • If an operation produces a *QuantTensor, that QuantTensor should contain valid data
  • An average pool should produce something close to an average of its inputs, unless the divisor_override parameter is used

Proposal

In order to correct these issues, I propose the following:

  • The scale should be adjusted after an AvgPool functional call so that a valid IntQuantTensor is produced
  • The TruncQuant operator should not reinterpret its output; instead, its scale should be adjusted by the amount of truncation that has occurred
  • A clamping operator should be added to TruncQuant to avoid overflow / underflow (see the sketch after this list)
  • The average pool shouldn't manipulate IntQuantTensors - this should be handled by the functional call and TruncQuant
  • The previous behaviour should be retained when input_bitwidth==output_bitwidth and the divisor_override parameter is set to 2**math.ceil(math.log2(k*k)). Other scenarios require careful manipulation of divisor_override to be achieved.
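To make the second and third points concrete, here is a minimal sketch of the proposed TruncQuant semantics (a hypothetical helper written for illustration, not Brevitas API):

import torch

def proposed_trunc_quant(int_val, scale, input_bit_width, output_bit_width, signed=True):
    # Truncate with rounding, clamp to the output range, and fold the shift
    # into the scale instead of reinterpreting the result.
    shift = 2 ** (input_bit_width - output_bit_width)
    y = torch.round(int_val / shift)
    min_int = -(2 ** (output_bit_width - 1)) if signed else 0
    max_int = 2 ** (output_bit_width - 1) - 1 if signed else 2 ** output_bit_width - 1
    y = torch.clamp(y, min_int, max_int) # guard against overflow / underflow
    return y, scale * shift # the scale absorbs the truncation

y, new_scale = proposed_trunc_quant(torch.tensor([127.]), 1.0, 8, 4)
print(y * new_scale) # tensor([112.]): saturates at the 4-bit maximum, close to 127

Under this scheme the dequantized output stays close to the input, and the overflow case from the first example (255 -> 16) would saturate instead of producing an invalid tensor.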

Furthermore, I question the usefulness of the current TruncAvgPool in hardware, for the following reasons:

  • Forcing a power-of-two denominator for the average is only beneficial on hardware that only supports power-of-two scaling (otherwise the denominator can be absorbed and streamlined in FINN, for example)
  • Requiring that the most-significant bits after a summation are kept is likely not optimal unless the input values are all very high in magnitude

So finally, I recommend:

  • Introducing a QuantAvgPool layer, which can be instantiated with any sensible activation quantizer
  • Creating a default activation quantizer which only allows power-of-two shifts in scale w.r.t. the output of the AvgPool functional call (a rough sketch follows this list)
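As a rough illustration of what that could look like (an assumption about the API shape, not an existing Brevitas layer): the proposed QuantAvgPool could amount to composing a plain average pool with a requantization step whose quantizer restricts scales to powers of two, e.g. Brevitas' Int8ActPerTensorFixedPoint:

import torch.nn as nn
from brevitas.nn import QuantIdentity
from brevitas.quant import Int8ActPerTensorFixedPoint # power-of-two scales

# Hypothetical stand-in for the proposed QuantAvgPool layer.
quant_avg_pool = nn.Sequential(
    nn.AvgPool2d(kernel_size=3),
    QuantIdentity(act_quant=Int8ActPerTensorFixedPoint, return_quant_tensor=True),
)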

@nickfraser nickfraser self-assigned this Oct 4, 2024
@nickfraser (Collaborator, Author) commented:

CC @volcacius.

@nickfraser (Collaborator, Author) commented:

CC @auphelia.

@Giuseppe5 Giuseppe5 added the next release PRs which should be merged for the next release label Oct 14, 2024